/*
 * Copyright 2023 WebAssembly Community Group participants
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <optional>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "ir/subtype-exprs.h"
#include "ir/subtypes.h"
#include "ir/type-updating.h"
#include "ir/utils.h"
#include "pass.h"
#include "support/unique_deferring_queue.h"
#include "wasm-traversal.h"
#include "wasm-type.h"
#include "wasm.h"

// Compute and use the minimal subtype relation required to maintain module
// validity and behavior. This minimal relation will be a subset of the original
// subtype relation. Start by walking the IR and collecting pairs of types that
// need to be in the subtype relation for each expression to validate. For
// example, a local.set requires that the type of its operand be a subtype of
// the local's type. Casts do not generate subtypings at this point because it
// is not necessary for the cast target to be a subtype of the cast source for
// the cast to validate.
//
// From that initial subtype relation, we then start finding new subtypings that
// are required by the subtypings we have found already. These transitively
// required subtypings come from two sources.
//
// The first source is type definitions. Consider these type definitions:
//
//   (type $A (sub (struct (ref $X))))
//   (type $B (sub $A (struct (ref $Y))))
//
// If we have determined that $B must remain a subtype of $A, then we know that
// $Y must remain a subtype of $X as well, since the type definitions would not
// be valid otherwise. Similarly, knowing that $Y must remain a subtype of $X
// may transitively require other subtypings as well based on their type
// definitions.
//
// The second source of transitive subtyping requirements is casts. Although
// casting from one type to another does not necessarily require that those
// types are related, we do need to make sure that we do not change the
// behavior of casts by removing subtype relationships they might observe. For
// example, consider this module:
//
// (module
//  ;; original subtyping: $bot <: $mid <: $top
//  (type $top (sub (struct)))
//  (type $mid (sub $top (struct)))
//  (type $bot (sub $mid (struct)))
//
//  (func $f
//   (local $top (ref $top))
//   (local $mid (ref $mid))
//
//   ;; Requires $bot <: $top
//   (local.set $top (struct.new $bot))
//
//   ;; Cast $top to $mid
//   (local.set $mid (ref.cast (ref $mid) (local.get $top)))
//  )
// )
//
// The only subtype relation directly required by the IR for this module is $bot
// <: $top. However, if we optimized the module so that $bot <: $top was the
// only subtype relation, we would change the behavior of the cast. In the
// original module, a value of type (ref $bot) is cast to (ref $mid). The cast
// succeeds because in the original module, $bot <: $mid. If we optimize so that
// we have $bot <: $top and no other subtypings, though, the cast will fail
// because the value of type (ref $bot) no longer inhabits (ref $mid). To
// prevent the cast's behavior from changing, we need to ensure that $bot <:
// $mid.
//
// The set of subtyping requirements generated by a cast from $src to $dest is
// that for every known remaining subtype $v of $src, if $v <: $dest in the
// original module, then $v <: $dest in the optimized module. In other words,
// for every type $v of values we know can flow into the cast, if the cast would
// have succeeded for values of type $v before, then we know the cast must
// continue to succeed for values of type $v. These requirements arising from
// casts can also generate transitive requirements because we learn about new
// types of values that can flow into casts as we learn about new subtypes of
// cast sources.
//
// Starting with the initial subtype relation determined by walking the IR,
// repeatedly search for new subtypings by analyzing type definitions and casts
// in lock step until we reach a fixed point. This is the minimal subtype
// relation that preserves module validity and behavior that can be found
// without a more precise analysis of types that might flow into each cast.

namespace wasm {

namespace {

// Pass that replaces the declared supertypes of types with the minimal set of
// subtype relations needed to keep the module valid and to preserve the
// behavior of casts (see the comment at the top of this file for the
// algorithm). Public types keep their declared supertypes.
struct Unsubtyping
  : WalkerPass<
      ControlFlowWalker<Unsubtyping, SubtypingDiscoverer<Unsubtyping>>> {
  // The new set of supertype relations: each key must remain a subtype of its
  // mapped value in the optimized module. Mapped values that are basic heap
  // types are kept for analysis but never become declared supertypes (see
  // optimizeTypes()).
  std::unordered_map<HeapType, HeapType> supertypes;

  // Map from cast source types to the set of cast destination types seen for
  // that source.
  std::unordered_map<HeapType, std::unordered_set<HeapType>> castTypes;

  // The set of subtypes that need to have their type definitions analyzed to
  // transitively find other subtype relations they depend on. We add to it
  // every time we find a new subtype relationship we need to keep.
  UniqueDeferredQueue<HeapType> work;

  void run(Module* wasm) override {
    // This pass only does anything for modules with GC enabled.
    if (!wasm->features.hasGC()) {
      return;
    }
    // Public types must keep their declared supertypes.
    analyzePublicTypes(*wasm);
    // Collect the subtypings and casts that the IR requires directly.
    walkModule(wasm);
    // Iterate to a fixed point over type definitions and casts.
    analyzeTransitiveDependencies();
    // Rewrite the module's types using the minimized supertypes.
    optimizeTypes(*wasm);
    // Cast types may be refinable if their source and target types are no
    // longer related. TODO: Experiment with running this only after checking
    // whether it is necessary.
    ReFinalize().run(getPassRunner(), wasm);
  }

  // Note that sub must remain a subtype of super. Identical types and
  // relations involving bottom types are ignored. If sub already has a
  // recorded supertype, the deeper of the two is kept and the relation between
  // the two supertypes is recorded recursively, so that `supertypes` always
  // describes a single chain above each type.
  void noteSubtype(HeapType sub, HeapType super) {
    if (sub == super || sub.isBottom() || super.isBottom()) {
      return;
    }

    auto [it, inserted] = supertypes.insert({sub, super});
    if (inserted) {
      work.push(sub);
      // TODO: Incrementally check all subtypes (inclusive) of sub against super
      // and all its supertypes if we have already analyzed casts.
      return;
    }
    // We already had a recorded supertype. The new supertype might be deeper,
    // shallower, or identical to the old supertype.
    auto oldSuper = it->second;
    if (super == oldSuper) {
      return;
    }
    // There are two different supertypes, but each type can only have a single
    // direct supertype, so the supertype chain cannot fork and one of the
    // supertypes must be a supertype of the other. Recursively record that
    // relationship as well.
    if (HeapType::isSubType(super, oldSuper)) {
      // sub <: super <: oldSuper
      it->second = super;
      work.push(sub);
      // TODO: Incrementally check all subtypes (inclusive) of sub against super
      // if we have already analyzed casts.
      noteSubtype(super, oldSuper);
    } else {
      // sub <: oldSuper <: super
      noteSubtype(oldSuper, super);
    }
  }

  // Note a subtyping between value types. Tuples are handled elementwise and
  // non-reference types impose no heap type constraints.
  void noteSubtype(Type sub, Type super) {
    if (sub.isTuple()) {
      assert(super.isTuple() && sub.size() == super.size());
      for (size_t i = 0, size = sub.size(); i < size; ++i) {
        noteSubtype(sub[i], super[i]);
      }
      return;
    }
    if (!sub.isRef() || !super.isRef()) {
      return;
    }
    noteSubtype(sub.getHeapType(), super.getHeapType());
  }

  // Note a subtyping where one or both sides are expressions.
  void noteSubtype(Expression* sub, Type super) {
    noteSubtype(sub->type, super);
  }
  void noteSubtype(Type sub, Expression* super) {
    noteSubtype(sub, super->type);
  }
  void noteSubtype(Expression* sub, Expression* super) {
    noteSubtype(sub->type, super->type);
  }

  // Record a cast from src to dest. Trivial casts (to the same type or to a
  // bottom type) cannot be affected by removing subtypings and are skipped.
  void noteCast(HeapType src, HeapType dest) {
    if (src == dest || dest.isBottom()) {
      return;
    }
    assert(HeapType::isSubType(dest, src));
    castTypes[src].insert(dest);
  }

  // Record a cast between value types. Casts of unreachable values are
  // ignored.
  void noteCast(Type src, Type dest) {
    assert(!src.isTuple() && !dest.isTuple());
    if (src == Type::unreachable) {
      return;
    }
    assert(src.isRef() && dest.isRef());
    noteCast(src.getHeapType(), dest.getHeapType());
  }

  // Note a cast where one or both sides are expressions.
  void noteCast(Expression* src, Type dest) { noteCast(src->type, dest); }
  void noteCast(Expression* src, Expression* dest) {
    noteCast(src->type, dest->type);
  }

  void analyzePublicTypes(Module& wasm) {
    // We cannot change supertypes for anything public.
    for (auto type : ModuleUtils::getPublicHeapTypes(wasm)) {
      if (auto super = type.getDeclaredSuperType()) {
        noteSubtype(type, *super);
      }
    }
  }

  void analyzeTransitiveDependencies() {
    // While we have found new subtypings and have not reached a fixed point...
    while (!work.empty()) {
      // Subtype relationships that we are keeping might depend on other subtype
      // relationships that we are not yet planning to keep. Transitively find
      // all the relationships we need to keep all our type definitions valid.
      while (!work.empty()) {
        auto type = work.pop();
        auto super = supertypes.at(type);
        if (super.isBasic()) {
          continue;
        }
        if (type.isStruct()) {
          // Each field of the supertype constrains the corresponding field of
          // the subtype.
          const auto& fields = type.getStruct().fields;
          const auto& superFields = super.getStruct().fields;
          for (size_t i = 0, size = superFields.size(); i < size; ++i) {
            noteSubtype(fields[i].type, superFields[i].type);
          }
        } else if (type.isArray()) {
          auto elem = type.getArray().element;
          noteSubtype(elem.type, super.getArray().element.type);
        } else {
          assert(type.isSignature());
          // Params are related in the opposite direction from results.
          auto sig = type.getSignature();
          auto superSig = super.getSignature();
          noteSubtype(superSig.params, sig.params);
          noteSubtype(sig.results, superSig.results);
        }
      }

      // Analyze all casts at once. Any new subtyping found here is pushed onto
      // `work`, which keeps the outer loop going.
      // TODO: This is expensive. Analyze casts incrementally after we
      // initially analyze them.
      analyzeCasts();
    }
  }

  void analyzeCasts() {
    // For each cast (src, dest) pair, any type that remains a subtype of src
    // (meaning its values can inhabit locations typed src) and that was
    // originally a subtype of dest (meaning its values would have passed the
    // cast) should remain a subtype of dest so that its values continue to pass
    // the cast.
    //
    // For every type, walk up its new supertype chain to find cast sources and
    // compare against their associated cast destinations.
    //
    // noteSubtype() may insert new entries into `supertypes`, which can rehash
    // the map and invalidate every outstanding iterator into it. To stay safe,
    // iterate over a snapshot of the current subtypes and never hold an
    // iterator into `supertypes` across a call to noteSubtype(). Any change
    // noteSubtype() makes is pushed onto `work`, so analyzeTransitiveDependencies
    // calls us again and entries added during this call are examined then.
    std::vector<HeapType> subtypes;
    subtypes.reserve(supertypes.size());
    for (const auto& entry : supertypes) {
      subtypes.push_back(entry.first);
    }
    for (auto type : subtypes) {
      auto curr = type;
      while (true) {
        auto srcIt = supertypes.find(curr);
        if (srcIt == supertypes.end()) {
          break;
        }
        // Copy the supertype out now; srcIt may be invalidated below.
        auto src = srcIt->second;
        if (auto destsIt = castTypes.find(src); destsIt != castTypes.end()) {
          // castTypes is never modified by noteSubtype(), so iterating its
          // destination set here is safe.
          for (auto dest : destsIt->second) {
            if (HeapType::isSubType(type, dest)) {
              noteSubtype(type, dest);
            }
          }
        }
        curr = src;
      }
    }
  }

  // Rewrite all types in the module so that each type's declared supertype is
  // the one recorded in `supertypes`, or none if no (non-basic) supertype was
  // recorded.
  void optimizeTypes(Module& wasm) {
    struct Rewriter : GlobalTypeRewriter {
      Unsubtyping& parent;
      Rewriter(Unsubtyping& parent, Module& wasm)
        : GlobalTypeRewriter(wasm), parent(parent) {}
      std::optional<HeapType> getDeclaredSuperType(HeapType type) override {
        if (auto it = parent.supertypes.find(type);
            it != parent.supertypes.end() && !it->second.isBasic()) {
          return it->second;
        }
        return std::nullopt;
      }
    };
    Rewriter(*this, wasm).update();
  }

  void doWalkModule(Module* wasm) {
    // Visit the functions in parallel, filling in `supertypes` and `castTypes`
    // on separate instances which will later be merged.
    ModuleUtils::ParallelFunctionAnalysis<Unsubtyping> analysis(
      *wasm, [&](Function* func, Unsubtyping& unsubtyping) {
        if (!func->imported()) {
          unsubtyping.walkFunctionInModule(func, wasm);
        }
      });
    // Collect the results from the functions.
    for (auto& [_, unsubtyping] : analysis.map) {
      for (const auto& [sub, super] : unsubtyping.supertypes) {
        noteSubtype(sub, super);
      }
      for (auto& [src, dests] : unsubtyping.castTypes) {
        for (auto dest : dests) {
          noteCast(src, dest);
        }
      }
    }
    // Collect constraints from top-level items.
    for (auto& global : wasm->globals) {
      visitGlobal(global.get());
    }
    for (auto& seg : wasm->elementSegments) {
      visitElementSegment(seg.get());
    }
    // Visit the rest of the code that is not in functions.
    walkModuleCode(wasm);
  }
};

} // anonymous namespace

// Creates a fresh, heap-allocated instance of the Unsubtyping pass.
// NOTE(review): ownership of the returned pointer presumably passes to the
// caller (e.g. the pass registry/runner) — confirm.
Pass* createUnsubtypingPass() {
  Pass* pass = new Unsubtyping();
  return pass;
}

} // namespace wasm
